# coding: utf-8
# @Author: cyl
# @File: service_redis.py
# @Time: 2023/02/20 22:14:31
import time
import functools
from typing import Union
from redis import StrictRedis
from redis.exceptions import ConnectionError

# Total number of attempts made by the retried RedisClient command wrappers.
REDIS_MAX_RETRY_TIMES: int = 3
# Seconds to sleep between two consecutive attempts.
REDIS_WAIT_TIME: int = 1


def retry(max_retry_times=1, wait_time=0.1):
    """Build a decorator that re-runs the wrapped call when it raises.

    Args:
        max_retry_times: total number of attempts. Values below 1 are treated
            as 1 so the wrapped function always runs at least once.
        wait_time: seconds to sleep between consecutive attempts. There is no
            sleep after the final attempt.

    The decorated function raises ``Exception`` (message naming the wrapped
    function) after the last failed attempt. The original error is attached
    as ``__cause__``.
    """
    attempts = max(1, max_retry_times)

    def _(func):
        @functools.wraps(func)
        def wrap_func(*args, **kwargs):
            for attempt in range(attempts):
                try:
                    return func(*args, **kwargs)
                except Exception as err:
                    # redis ConnectionError and every other error are handled
                    # alike; custom logging can be added here.
                    if attempt == attempts - 1:
                        raise Exception(
                            f"redis client command `{func.__name__}` error!"
                        ) from err
                time.sleep(wait_time)
        return wrap_func
    return _


class RedisClient(object):
    """Redis client wrapping a lazily created ``StrictRedis`` connection.

    Commands are run through :meth:`execute`. A command ``cmd`` is served by
    the method ``_<cmd>`` defined on this class when one exists. Those
    wrappers are retried via ``retry``, and the read wrappers decode ``bytes``
    replies to ``str``. Every other command is forwarded unchanged to the
    ``StrictRedis`` client.
    """

    def __init__(self, config: dict) -> None:
        """Store ``config``, the keyword arguments passed to ``StrictRedis``.

        No connection is opened here; :meth:`create` does that on first use.
        """
        self._conn = None
        self.config = config

    def create(self) -> None:
        """Create the redis connection unless one already exists."""
        if self._conn is None:
            self._conn = StrictRedis(**self.config)

    def execute(self, cmd: str, *vargs, **kwargs):
        """The entrypoint to execute the command ``cmd``.

        Args:
            cmd: redis command name, e.g. ``"set"`` or ``"lrange"``.
            *vargs: positional arguments for the command.
            **kwargs: keyword arguments for the command (e.g. ``ex=`` for
                ``set``). Wrappers accept only the keyword names in their own
                signature.

        Returns:
            The command's reply, decoded to ``str`` by the decoding wrappers.

        Raises:
            AttributeError: ``cmd`` is neither a wrapper on this class nor an
                attribute of ``StrictRedis``. The connection is kept.
            Exception: any error raised while running the command. The
                connection is discarded first so the next call reconnects.
        """
        self.create()
        handler_name = f"_{cmd}"
        # Look wrappers up on the class, not the instance, so instance
        # attributes such as ``_conn`` are never mistaken for commands.
        # Names starting with "_" are excluded so dunder methods (e.g.
        # ``__init__``) cannot be reached through execute().
        if not cmd.startswith("_") and callable(
                getattr(type(self), handler_name, None)):
            handler = getattr(self, handler_name)
        else:
            # Resolved outside the try: an unknown command is a caller error
            # and should not tear down a healthy connection.
            handler = getattr(self._conn, cmd)
        try:
            return handler(*vargs, **kwargs)
        except Exception:
            self.destroy()
            raise

    def destroy(self) -> None:
        """Destroy the redis connection; the next command will reconnect."""
        conn, self._conn = self._conn, None
        # Only tear down a pool this client created itself. A caller-supplied
        # ``connection_pool`` in ``config`` may be shared with other clients.
        if conn is None or "connection_pool" in self.config:
            return
        try:
            conn.connection_pool.disconnect()
        except Exception:
            # Best-effort cleanup. destroy() runs inside execute()'s error
            # path, and a failure here must not mask the original error.
            pass

    def decode(self, data: Union[bytes, str, None]) -> Union[str, None]:
        """Decode UTF-8 ``bytes`` to ``str``.

        ``None`` is returned as-is. ``str`` passes through unchanged so the
        client also works when ``config`` sets ``decode_responses=True``.
        """
        if data is None or isinstance(data, str):
            return data
        return str(data, encoding='utf-8')

    def decode_dict(self, data: Union[dict, None]) -> Union[dict, None]:
        """Decode ``Dict[bytes, bytes]`` to ``Dict[str, str]``.

        Each key and value is converted with :meth:`decode`. ``None`` is
        returned as-is.
        """
        if data is None:
            return None
        return {self.decode(k): self.decode(v) for k, v in data.items()}

    def decode_iterable(self, data: Union[list, set, tuple, None]):
        """Decode every element of ``data`` with :meth:`decode`.

        List[bytes] -> List[str]
        Tuple[bytes] -> Tuple[str]
        Set[bytes] -> Set[str]
        Any other iterable is returned as a list; ``None`` is returned as-is.
        """
        if data is None:
            return None
        if isinstance(data, list):
            return list(map(self.decode, data))
        elif isinstance(data, set):
            return set(map(self.decode, data))
        elif isinstance(data, tuple):
            return tuple(map(self.decode, data))
        else:
            return list(map(self.decode, data))

    # Command wrappers reached via execute("<name>", ...). Each one is
    # retried via ``retry``; the read commands decode their replies.

    @retry(max_retry_times=REDIS_MAX_RETRY_TIMES, wait_time=REDIS_WAIT_TIME)
    def _set(self, name, value, **kwargs):
        # kwargs are forwarded as-is to StrictRedis.set (e.g. ex/px/nx/xx).
        return self._conn.set(name, value, **kwargs)

    @retry(max_retry_times=REDIS_MAX_RETRY_TIMES, wait_time=REDIS_WAIT_TIME)
    def _get(self, name):
        return self.decode(self._conn.get(name))

    @retry(max_retry_times=REDIS_MAX_RETRY_TIMES, wait_time=REDIS_WAIT_TIME)
    def _setnx(self, name, value):
        return self._conn.setnx(name, value)

    @retry(max_retry_times=REDIS_MAX_RETRY_TIMES, wait_time=REDIS_WAIT_TIME)
    def _lpush(self, name, *values):
        return self._conn.lpush(name, *values)

    @retry(max_retry_times=REDIS_MAX_RETRY_TIMES, wait_time=REDIS_WAIT_TIME)
    def _lrange(self, name, start, end):
        return self.decode_iterable(self._conn.lrange(name, start, end))

    @retry(max_retry_times=REDIS_MAX_RETRY_TIMES, wait_time=REDIS_WAIT_TIME)
    def _incr(self, name, amount=1):
        return self._conn.incr(name, amount)

    # ``incrby`` shares the ``incr`` wrapper: both take (name, amount=1).
    _incrby = _incr

    @retry(max_retry_times=REDIS_MAX_RETRY_TIMES, wait_time=REDIS_WAIT_TIME)
    def _hgetall(self, name):
        return self.decode_dict(self._conn.hgetall(name))

    @retry(max_retry_times=REDIS_MAX_RETRY_TIMES, wait_time=REDIS_WAIT_TIME)
    def _sadd(self, name, *values):
        return self._conn.sadd(name, *values)

    @retry(max_retry_times=REDIS_MAX_RETRY_TIMES, wait_time=REDIS_WAIT_TIME)
    def _sismember(self, name, value):
        return self._conn.sismember(name, value)

    @retry(max_retry_times=REDIS_MAX_RETRY_TIMES, wait_time=REDIS_WAIT_TIME)
    def _smembers(self, name):
        return self.decode_iterable(self._conn.smembers(name))

    @retry(max_retry_times=REDIS_MAX_RETRY_TIMES, wait_time=REDIS_WAIT_TIME)
    def _srem(self, name, *values):
        return self._conn.srem(name, *values)

    @retry(max_retry_times=REDIS_MAX_RETRY_TIMES, wait_time=REDIS_WAIT_TIME)
    def _delete(self, *names):
        return self._conn.delete(*names)


if __name__ == "__main__":
    # Small demo against a Redis server on localhost, default port, db 0.
    config = {"host": "127.0.0.1", "port": 6379, "db": 0}
    redis_cli: RedisClient = RedisClient(config)

    # (command, *arguments) pairs whose replies are printed verbatim.
    demo_commands = (
        ("set", "key1", "val1"),
        ("get", "key1"),
        ("lpush", "alist", "1"),
        ("lrange", "alist", 0, -1),
        ("setnx", "key41", "val4"),
        ("incr", "key333"),
    )
    for command, *arguments in demo_commands:
        print(redis_cli.execute(command, *arguments))

    # rpush has no wrapper, so it is forwarded straight to StrictRedis.
    print(f'rpush: {redis_cli.execute("rpush", "alist", 10)}')
